import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader, TensorDataset, random_split, RandomSampler
import numpy as np
import copy

#bootstrap_bs = 1
#BatchNum = 100
#torchseed = 1

def _flatten_gradients(params):
    """Return the gradients of ``params`` concatenated into one 1-D tensor.

    A parameter whose ``.grad`` is ``None`` (it did not take part in the most
    recent backward pass) contributes zeros, so the result always has
    ``sum(p.numel() for p in params)`` entries. ``torch.cat`` copies, so the
    result is unaffected by later resets of the gradients.
    """
    return torch.cat([
        (p.grad if p.grad is not None else torch.zeros_like(p)).reshape(-1)
        for p in params
    ])


def GradientVariance(test_model, inputs, labels, N_train):
    """Measure the spread of per-sample MSE gradients of a model.

    The model is deep-copied first, so ``test_model`` (including its ``.grad``
    buffers) is never modified. Only parameters with ``requires_grad=True``
    are taken into account.

    Args:
        test_model: Callable model exposing ``parameters()`` and
            ``zero_grad()`` (assumed to be a ``torch.nn.Module`` -- confirm
            against callers).
        inputs: Tensor of model inputs; the first dimension indexes samples.
        labels: Tensor of targets with the same first dimension as ``inputs``.
        N_train: Number of entries the per-sample statistic is averaged over.
            Expected to equal the number of samples: fewer samples leave
            zero entries that lower the mean, more samples raise
            ``IndexError``.

    Returns:
        A tuple ``(Variance, FullGradientNorm)``:

        * ``Variance`` -- 0-dim tensor, the mean over ``N_train`` entries of
          the squared L2 norm of each single-sample gradient. NOTE: this is
          the uncentred second moment; the full gradient is not subtracted.
        * ``FullGradientNorm`` -- tensor of shape ``(1,)`` holding the L2 norm
          of the gradient of the MSE loss over the whole of ``inputs``.
    """
    model = copy.deepcopy(test_model)
    # Flatten exactly the parameters that receive gradients, so the gradient
    # length can never disagree with a separately computed parameter count.
    params = [p for p in model.parameters() if p.requires_grad]
    criterion = torch.nn.MSELoss()

    # One sample per batch, in dataset order.
    dataset = TensorDataset(inputs, labels)
    single_loader = DataLoader(dataset=dataset, batch_size=1)

    # Gradient of the loss over the full data set.
    model.zero_grad()
    criterion(model(inputs), labels).backward()
    full_gradient = _flatten_gradients(params)

    # Only the squared norm of every per-sample gradient is needed, so the
    # gradients themselves are not stored. ``new_zeros`` keeps the buffer on
    # the same device and dtype as the gradients.
    squared_norms = full_gradient.new_zeros(N_train)
    for idx, (x, y) in enumerate(single_loader):
        model.zero_grad()  # gradients would otherwise accumulate across samples
        criterion(model(x), y).backward()
        squared_norms[idx] = _flatten_gradients(params).pow(2).sum()

    Variance = squared_norms.mean()
    return Variance, torch.norm(full_gradient.view(1, -1), p=2, dim=1)